import cv2
import numpy as np
import h5py
import os
import glob


def list_to_npy_trainval(list_path, npy_path_im, npy_path_la, resize,
                         train_folder_path='/media/dell/delldisk/dell/wxm/Data/KaggleDDD/train'):
    """Crop and resize every image named in a list file and save them as .npy.

    Each non-blank line of ``list_path`` must be ``<sub_path> <label>``, where
    ``sub_path`` is relative to ``train_folder_path`` and ``label`` is an
    integer.  Every image is read with OpenCV, its first and third colour
    channels are swapped (``cv2.COLOR_RGB2BGR``; OpenCV decodes colour images
    in B, G, R order, so the result is R, G, B), cropped, resized to
    ``resize x resize`` with ``INTER_AREA`` and stacked.

    Parameters
    ----------
    list_path : str
        Text file listing ``<sub_path> <label>`` pairs, one per line.
    npy_path_im : str
        Output path for the uint8 image array, shape (N, resize, resize, 3).
    npy_path_la : str
        Output path for the uint8 label array, shape (N,).
    resize : int
        Side length of the square output images.
    train_folder_path : str, optional
        Root folder that the ``sub_path`` entries are relative to.

    Raises
    ------
    IOError
        If an image listed in ``list_path`` cannot be read.
    ValueError
        If a line does not consist of a path followed by an integer label.
    """
    img_arrs = []
    labels = []
    with open(list_path, 'r') as f_list:
        for line in f_list:
            if not line.strip():
                continue  # tolerate blank / trailing lines
            # rsplit(None, 1) also drops the line's trailing newline from the label.
            sub_path, label = line.rsplit(None, 1)
            img_path = os.path.join(train_folder_path, sub_path)
            img_arr = cv2.imread(img_path)
            if img_arr is None:
                # cv2.imread signals failure by returning None rather than raising.
                raise IOError('could not read image: {0}'.format(img_path))
            img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGB2BGR)
            # NOTE(review): the second slice is applied to the already-cut
            # columns 320:, so it keeps original columns 400:880 (clipped to the
            # image width), not 80:560 -- confirm this double crop is intended.
            img_arr = img_arr[80:400, 320:]
            img_arr = img_arr[:, 80:560]
            img_arr = cv2.resize(img_arr, (resize, resize), interpolation=cv2.INTER_AREA)
            img_arrs.append(img_arr)
            labels.append(int(label))
    np.save(npy_path_im, np.asarray(img_arrs, np.uint8))
    np.save(npy_path_la, np.asarray(labels, np.uint8))


# Per-channel mean pixel values subtracted from images in list_to_h5_trainval.
# NOTE(review): these look like the BGR ImageNet means commonly used with
# Caffe VGG models (0-255 scale) -- confirm the channel order matches the data.
mean_vec = np.array([103.939, 116.779, 123.68], dtype=np.float32)
# The same three values viewed as (3, 1, 1) so they broadcast over an
# (N, 3, H, W) batch.
reshaped_mean_vec = mean_vec[:, np.newaxis, np.newaxis]


def list_to_h5_trainval(list_path, h5_path, resize,
                        train_folder_path='/media/dell/delldisk/dell/Rui/data/Kaggle/DistractedDriverDetection/imgs/train'):
    """Crop, resize and mean-subtract listed images and store them in HDF5.

    Each non-blank line of ``list_path`` must be ``<sub_path> <label>`` with
    ``sub_path`` relative to ``train_folder_path`` and an integer ``label``.
    Images are read with OpenCV (B, G, R channel order; no colour
    conversion is applied), cropped, resized to ``resize x resize`` with
    ``INTER_AREA``, converted to float32 in (N, C, H, W) layout, and have the
    module-level ``reshaped_mean_vec`` subtracted per channel.

    The HDF5 file ``h5_path`` receives two datasets:

    * ``data``  -- float32, shape (N, 3, resize, resize)
    * ``label`` -- uint8, shape (N, 1)

    Parameters
    ----------
    list_path : str
        Text file listing ``<sub_path> <label>`` pairs, one per line.
    h5_path : str
        HDF5 file to create or open for writing.
    resize : int
        Side length of the square output images.
    train_folder_path : str, optional
        Root folder that the ``sub_path`` entries are relative to.  Other
        locations used previously:
        '/media/lab/labdisk/home/lab1314/xryu/data/DistractedDriverDetection/imgs/train',
        '/media/dell/delldisk/dell/wxm/Data/KaggleDDD/train'.

    Raises
    ------
    IOError
        If an image listed in ``list_path`` cannot be read.
    ValueError
        If a line is malformed or ``list_path`` lists no images.
    """
    img_arrs = []
    labels = []
    with open(list_path, 'r') as f_list:
        for line in f_list:
            if not line.strip():
                continue  # tolerate blank / trailing lines
            # rsplit(None, 1) also drops the line's trailing newline from the label.
            sub_path, label = line.rsplit(None, 1)
            img_path = os.path.join(train_folder_path, sub_path)
            img_arr = cv2.imread(img_path)
            if img_arr is None:
                # cv2.imread signals failure by returning None rather than raising.
                raise IOError('could not read image: {0}'.format(img_path))
            img_arr = img_arr[80:400, 320:]
            img_arr = img_arr[:, 80:560]
            # Keep raw 0-255 pixel values: the mean subtracted below is on that
            # scale.  Dividing the uint8 image by 256 here would give all zeros
            # under Python 2 integer division, and a mismatched scale otherwise.
            img_arr = cv2.resize(img_arr, (resize, resize), interpolation=cv2.INTER_AREA)
            img_arrs.append(img_arr)
            labels.append(int(label))
    if not img_arrs:
        raise ValueError('no images listed in {0}'.format(list_path))

    # (N, H, W, C) -> (N, C, H, W) so the (3, 1, 1) mean broadcasts per channel.
    img_arrs = np.asarray(img_arrs, np.float32).transpose((0, 3, 1, 2))
    labels = np.asarray(labels, np.uint8).reshape((len(labels), 1))
    print(img_arrs.shape)
    print(labels.shape)
    img_arrs = img_arrs - reshaped_mean_vec

    # Explicit mode 'a' (read/write, create if missing): h5py >= 3.0 opens
    # files read-only by default, which would make create_dataset fail.
    with h5py.File(h5_path, 'a') as f:
        f.create_dataset('data', data=img_arrs)
        f.create_dataset('label', data=labels)


def crop_resize_img_from_list(list_path, resize, new_list_path):
    #train_folder_path = '/media/lab/labdisk/home/lab1314/xryu/data/DistractedDriverDetection/imgs/train'
    train_folder_path = '/media/dell/delldisk/dell/Rui/data/Kaggle/DistractedDriverDetection/imgs/train'
    img_root='/media/dell/delldisk/dell/Rui/data/Kaggle/DistractedDriverDetection/imgs'
    # train_folder_path = '/media/dell/delldisk/dell/wxm/Data/KaggleDDD/train'
    f_list = open(list_path, 'r')
    img_arrs = []
    labels = []
    crop = []
    subsubpath = ''
    f = open(new_list_path, 'w')
    cnt=0
    for line in f_list:
        sub_path, label = line.split(' ')
        img_path = train_folder_path + '/' + sub_path
        img_arr = cv2.imread(img_path)
        # img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGB2BGR)
        # img_arr = img_arr[80:400, 320:]
        img_arr = img_arr[:, 80:560]
        img_arr = cv2.resize(img_arr, (resize, resize), interpolation=cv2.INTER_AREA)
        crop.append(img_arr[:224, :224])
        crop.append(img_arr[:224, 256 - 224:256])
        crop.append(img_arr[256 - 224:256, :224])
        crop.append(img_arr[256 - 224:256, 256 - 224:256])
        crop.append(img_arr[128-112:128+112,128-112:128+112])
        crop.append(img_arr[128-112:128+112, :224])
        crop.append(img_arr[128 - 112:128+112, 256-224:256])
        crop.append(img_arr[:224, 128-112:128+112])
        crop.append(img_arr[256-224:256,128-112:128+112])        
        #print 'crop shape',np.asarray(crop).shape
        img_name = sub_path.split('/')[-1]
        if 'val' in list_path:
            subsubpath = 'val_cropped'
        elif 'train' in list_path:
            subsubpath = 'train_cropped'
        save_path = os.path.join(img_root,'cropped224',subsubpath, str(img_name))
        line_path1 = os.path.join('cropped224',subsubpath, str(img_name))
        new_line = line_path1 + ' ' + str(label)
        f.writelines(new_line)
        cv2.imwrite(save_path, img_arr)
        for i in range(np.asarray(crop).shape[0]):
            crop_save_path = os.path.join(img_root,'cropped224',subsubpath, str(img_name.split('.')[0]) + '_crop_' + str(i)+'.jpg')
            line_path = os.path.join('cropped224',subsubpath, str(img_name.split('.')[0]) + '_crop_' + str(i)+'.jpg')
            crop_line = line_path + ' ' + str(label)
            f.writelines(crop_line)
            cv2.imwrite(crop_save_path, crop[i])
            #print crop_save_path
        crop=[]
        cnt+=1
        if cnt%1000==0:
            print '{0}'.format(cnt)
    f.close()

data_txt = '/media/dell/delldisk/dell/Rui/data/Kaggle/DistractedDriverDetection/fine-tune-caffe/data_txt/'

# (input list, output 9-crop list), both relative to data_txt.  The training
# split is processed first, then validation.
# NOTE(review): this runs at import time, not only when executed as a script.
_CROP_JOBS = (
    ('driver_sorted/train_list.txt',
     'driver_sorted/25vs1/multi_crop/train_list_9cropped.txt'),
    ('driver_sorted/val_list.txt',
     'driver_sorted/25vs1/multi_crop/val_list_9cropped.txt'),
)
for _src_list, _dst_list in _CROP_JOBS:
    crop_resize_img_from_list(
        list_path=data_txt + _src_list,
        resize=256,
        new_list_path=data_txt + _dst_list,
    )
# Test / experimental code below.
# NOTE: oversample() is not yet fully understood; review it before relying on it.
def oversample(images, crop_dims):
    """
    Crop images into the four corners, center, and their mirrored versions.

    Parameters
    ----------
    images : array_like
        A single (H x W x K) image or a batch of N images (N x H x W x K).
    crop_dims : (height, width) tuple for the crops.

    Returns
    -------
    crops : (10*N x height x width x K) float32 ndarray.  For each input image
        the first five crops are top-left, top-right, bottom-left,
        bottom-right and center; the next five are the same crops mirrored
        left-to-right.

    Raises
    ------
    ValueError
        If ``images`` is not 3-D or 4-D, or a crop dimension exceeds the
        image size.
    """
    # NOTE(review): this appears to be adapted from Caffe's
    # ``caffe.io.oversample``; the former hard-coded (1, 256, 256, 3) reshape
    # is replaced by accepting any single image or batch.
    images = np.asarray(images)
    if images.ndim == 3:
        images = images[np.newaxis]  # single image -> batch of one
    if images.ndim != 4:
        raise ValueError('images must be (H, W, K) or (N, H, W, K), got shape {0}'.format(
            images.shape))
    im_h, im_w = images.shape[1:3]
    crop_h, crop_w = int(crop_dims[0]), int(crop_dims[1])
    if crop_h > im_h or crop_w > im_w:
        raise ValueError('crop {0}x{1} exceeds image {2}x{3}'.format(
            crop_h, crop_w, im_h, im_w))

    # Top-left corner of each of the five unmirrored crops.
    bottom, right = im_h - crop_h, im_w - crop_w
    corners = [(0, 0), (0, right), (bottom, 0), (bottom, right),
               (bottom // 2, right // 2)]

    crops = np.empty((10 * len(images), crop_h, crop_w, images.shape[-1]),
                     dtype=np.float32)
    ix = 0
    for im in images:
        for top, left in corners:
            crops[ix] = im[top:top + crop_h, left:left + crop_w, :]
            ix += 1
        # The next five slots hold the same crops flipped along the width axis.
        crops[ix:ix + 5] = crops[ix - 5:ix, :, ::-1, :]
        ix += 5
    return crops


# Disabled scratch test, kept as an unused string literal: it would read one
# training image, keep columns 80:560, resize it to 256x256 and print the
# shapes of its four 224x224 corner crops.
'''
test_img='/media/lab/labdisk/home/lab1314/xryu/data/DistractedDriverDetection/imgs/train/c0/img_34.jpg'
resize=256
crop_dim=(224,224)
img_arr=cv2.imread(test_img)
img_arr = img_arr[:, 80:560]
#print img_arr.shape
img_arr = cv2.resize(img_arr, (resize, resize), interpolation=cv2.INTER_AREA)
crop_1=img_arr[:224,:224]
crop_2=img_arr[:224,256-224:256]
crop_3=img_arr[256-224:256,:224]
crop_4=img_arr[256-224:256,256-224:256]
print crop_1.shape
print crop_2.shape
print crop_3.shape
print crop_4.shape
'''

# Disabled HDF5 export, kept as an unused string literal: it would convert the
# val and train lists with list_to_h5_trainval at resize=224, using paths
# under /media/lab/labdisk/... .
"""
list_to_h5_trainval(
    list_path='/media/lab/labdisk/home/lab1314/xryu/data/DistractedDriverDetection/fine-tune-caffe/data_txt/driver_sorted/val_list.txt',
    h5_path='/media/lab/labdisk/home/lab1314/xryu/data/DistractedDriverDetection/fine-tune-caffe/data_h5/val_caffe_RGB_mean_sub.h5',
    resize=224)
list_to_h5_trainval(
    list_path='/media/lab/labdisk/home/lab1314/xryu/data/DistractedDriverDetection/fine-tune-caffe/data_txt/driver_sorted/train_list.txt',
    h5_path='/media/lab/labdisk/home/lab1314/xryu/data/DistractedDriverDetection/fine-tune-caffe/data_h5/train_caffe_RGB_mean_sub.h5',
    resize=224)
"""
